# -*- coding: utf-8 -*-
"""A general dialog agent."""
from typing import Optional, Union, Sequence, Any

from loguru import logger

from ..message import Msg
from .agent import AgentBase


class DialogAgent(AgentBase):
    """A simple agent used to perform a dialogue. You can set its role by
    `sys_prompt`."""

    def __init__(
        self,
        name: str,
        sys_prompt: str,
        model_config_name: str,
        use_memory: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the dialog agent.

        Arguments:
            name (`str`):
                The name of the agent.
            sys_prompt (`str`):
                The system prompt of the agent, which can be passed by args
                or hard-coded in the agent.
            model_config_name (`str`):
                The name of the model config, which is used to load model from
                configuration.
            use_memory (`bool`, defaults to `True`):
                Whether the agent has memory.
            **kwargs (`Any`):
                Extra keyword arguments. They are NOT forwarded to
                `AgentBase`; a warning listing them is logged instead.
        """
        super().__init__(
            name=name,
            sys_prompt=sys_prompt,
            model_config_name=model_config_name,
            use_memory=use_memory,
        )

        # Extra kwargs are accepted but ignored; surface them so that a
        # misspelled argument name does not go unnoticed.
        if kwargs:
            logger.warning(
                f"Unused keyword arguments are provided: {kwargs}",
            )

    def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg:
        """Reply function of the agent. Processes the input data,
        generates a prompt using the current dialogue memory and system
        prompt, and invokes the language model to produce a response. The
        response is then formatted and added to the dialogue memory.

        Arguments:
            x (`Optional[Union[Msg, Sequence[Msg]]]`, defaults to `None`):
                The input message(s) to the agent, which also can be omitted if
                the agent doesn't need any input.

        Returns:
            `Msg`: The output message generated by the agent.
        """
        # `self.memory` is set up by `AgentBase`; presumably it is None when
        # `use_memory=False` (confirm in AgentBase). Test against None
        # explicitly instead of relying on the memory object's truthiness,
        # which would skip recording if the memory class defined `__len__`.
        if self.memory is not None:
            self.memory.add(x)

        # With memory enabled, send the stored dialogue history (which, after
        # the `add` above, should include `x`); otherwise send only `x`.
        history = self.memory.get_memory() if self.memory is not None else x

        # prepare prompt
        prompt = self.model.format(
            Msg("system", self.sys_prompt, role="system"),
            history,  # type: ignore[arg-type]
        )

        # call llm and generate response
        response = self.model(prompt)

        # Print/speak the message in this agent's voice.
        # `response.stream` is preferred when present (streaming mode);
        # otherwise fall back to the complete `response.text`.
        self.speak(response.stream or response.text)

        msg = Msg(self.name, response.text, role="assistant")

        # Record the message in memory
        if self.memory is not None:
            self.memory.add(msg)

        return msg
